#include "header.hpp"
#include "model.hpp"
#include "tensor.hpp"

#ifndef __NN_LAYER_H__
#define __NN_LAYER_H__
namespace nn {
// Discriminates concrete layer kinds (see LayerBase::layerTp()).
enum LayerType {
    liner,          // fully-connected ("linear") layer; name kept for existing callers
    linear = liner, // correctly-spelled alias for new code
    activation      // element-wise activation layer
};
// Abstract base of every network layer. Owns the parameter and gradient
// storage shared by the concrete layers (weights, bias, cached input).
class LayerBase {
protected:
  string _name;     // human-readable layer name
  Dmatrix _input;   // input cached by forward() for use in backward()
  Dmatrix _grad;    // gradient w.r.t. the weights
  Dtensor _grad_b;  // gradient w.r.t. the bias
  Dtensor _bias;    // bias parameters
  Dmatrix _weight;  // weight parameters

public:
  LayerBase(string name = "-") {
    assignable_constraint(name);
    _name = name;
  }
  // Virtual: layers are stored and destroyed through LayerBase* (see
  // Layerlist); a non-virtual destructor here makes delete-through-base
  // undefined behavior.
  virtual ~LayerBase() {}
  void set_layer_name(const string &name) noexcept { _name = name; }
  const string &get_layer_name(void) noexcept { return _name; }
  const Dmatrix &get_input(void) noexcept { return _input; }
  void set_input(const Dmatrix &x) noexcept { _input = x; }
  Dmatrix &get_grad() { return _grad; }
  Dtensor &get_grad_b() { return _grad_b; }
  void set_grad(const Dmatrix &oth) { _grad = oth; }
  const Dmatrix &get_onlyread_grad() { return _grad; }
  // Core interface implemented by every concrete layer.
  virtual Dmatrix forward(const Dmatrix &x) = 0;
  virtual Dmatrix backward(const Dmatrix &grad) = 0;
  virtual LayerType layerTp() = 0;
  const Dmatrix &get_w(void) noexcept { return _weight; }
  void set_w(const Dmatrix &w) noexcept { _weight = w; }
  // NOTE(review): parameter is a Dmatrix but the member is a Dtensor —
  // relies on implicit conversion between the two aliases; confirm intended.
  void set_b(const Dmatrix &b) noexcept { _bias = b; }
  const Dtensor &get_b() noexcept { return _bias; }
}; // class LayerBase

// Base for fully-connected ("linear") layers: allocates the weight matrix
// (in_features x out_features, random init) and the bias vector (ones).
class LinerBase : public LayerBase {
protected:
  int _in_features;   // input dimension
  int _out_features;  // output dimension
  bool _bias_enable;  // whether the bias term is applied in forward()

public:
  // NOTE(review): `bias` (an initial bias value?) is accepted but never
  // used — the bias is always initialized to ones; confirm intended.
  // Ctor is deliberately NOT noexcept: the Eigen-style allocations below
  // can throw, and a noexcept ctor would turn that into std::terminate.
  LinerBase(int in_features, int out_freatures, const string name = "LinerBase",
            bool use_bias = true, double bias = -1)
      : LayerBase(name), _in_features(in_features),
        _out_features(out_freatures), _bias_enable(use_bias) {
    (void)bias; // silence unused-parameter warning until it is wired up
    set_w(Dmatrix::Random(in_features, out_freatures));
    set_b(Dtensor::Ones(out_freatures));
    // Start the weight gradient at zero.
    Dmatrix &grad = get_grad();
    grad = Dtensor::Zero(_out_features);
  }

  LayerType layerTp() noexcept override { return liner; }

  ~LinerBase() {}

protected:
  void set_bias_enable(bool be) { _bias_enable = be; }
  bool get_bias_enable() { return _bias_enable; }
}; // class Layer
// Fully connected layer: y = x * W (+ b).
class Dense : public LinerBase {
protected:
public:
  // `use_bias` is now forwarded to LinerBase (previously it silently
  // defaulted to true); default preserves the old behavior.
  Dense(int in_features, int out_freatures, const string name = "Dense",
        bool use_bias = true)
      : LinerBase(in_features, out_freatures, name, use_bias) {}
  ~Dense() {}

  // Caches x for backward() and returns x*W (+ bias when enabled).
  Dmatrix forward(const Dmatrix &x) override {
    set_input(x);
    Dmatrix y = x * get_w();
    if (_bias_enable) y += _bias;
    return y;
  }

  // grad = dL/dy. Stores dL/dW = x^T * grad and dL/db; returns dL/dx = grad * W^T.
  // Not noexcept: the matrix products allocate and may throw.
  Dmatrix backward(const Dmatrix &grad) override {
    const Dmatrix &input = get_input();
    if (_bias_enable) {
      // Must bind by reference — the previous `auto _b = ...` copied the
      // accumulator, so the bias gradient was computed and then discarded.
      auto &_b = get_grad_b();
      // NOTE(review): for row-major batches dL/db is usually
      // grad.colwise().sum(); confirm the rowwise reduction is intended.
      _b = grad.rowwise().sum();
    }
    set_grad(input.transpose() * grad);
    return grad * get_w().transpose();
  }
}; // class Dense

namespace Activation {
using string = std::string;
// Base for element-wise activation layers. Concrete subclasses provide
// cacl() (the activation function) and derivative_cacl() (its derivative).
class ActivationBase : public LayerBase {
public:
  ActivationBase(const string name = "Activation") : LayerBase(name) {}

  // Caches x, resets the gradient buffer, and applies the activation.
  inline Dmatrix forward(const Dmatrix &x) noexcept override {
    set_input(x);
    set_grad(Dtensor::Zero(x.cols()));
    return cacl(x);
  }

  // Chain rule: dL/dx = f'(x) (element-wise) * upstream grad.
  inline Dmatrix backward(const Dmatrix &grad) noexcept override {
    return (derivative_cacl(get_input()).array() * grad.array());
  }
  LayerType layerTp() noexcept override { return activation; }
  virtual Dmatrix cacl(const Dmatrix &x) = 0;
  virtual Dmatrix derivative_cacl(const Dmatrix &x) = 0;
};

// Rectified linear unit: f(x) = max(x, 0).
class ReLU : public ActivationBase {
public:
  ReLU(const string name = "ReLU") noexcept : ActivationBase(name) {}
  inline Dmatrix cacl(const Dmatrix &x) noexcept override { return x.cwiseMax(0.0); }

  // f'(x): 1 where x > 0, else 0.
  Dmatrix derivative_cacl(const Dmatrix &x) override {
    Dmatrix grad = x;
    // Must bind by (forwarding) reference: the previous `auto i` iterated
    // over copies of each coefficient, so the writes were lost and the
    // function returned x unchanged.
    for (auto &&coeff : grad.reshaped()) {
      coeff = coeff > 0.0; // bool -> double: 1.0 or 0.0
    }
    return grad;
  }
};
} // namespace Activation

// Incomplete stub: Conv2d does not implement the pure virtual
// forward(const Dmatrix&)/backward/layerTp inherited from LayerBase, so it
// is still abstract and cannot be instantiated. Note it also uses PRIVATE
// inheritance, unlike every other layer — TODO confirm intended.
class Conv2d : private LayerBase {
private:
public:
  Conv2d(const string name = "Conv2d") : LayerBase(name){};
  ~Conv2d(){};
  // NOTE(review): this HIDES LayerBase::forward(const Dmatrix&) rather than
  // overriding it; placeholder body always returns a 1x1 matrix.
  Dmatrix forward(Dtensor x) {return Dmatrix(1,1);}
};
using Layerlist = std::vector<LayerBase *>; // layer list
} // namespace nn
#endif // __NN_LAYER_H__